


__all__ = ['ReplicationPad1d', 'TokenEmbedding', 'PatchEmbedding', 'FlattenHead', 'ReprogrammingLayer', 'TimeLLM']


import math
from typing import Optional

import torch
import torch.nn as nn

import neuralforecast.losses.pytorch as losses

from ..common._base_model import BaseModel
from ..common._modules import RevIN
from ..losses.pytorch import MAE

try:
    from transformers import AutoConfig, AutoModel, AutoTokenizer

    IS_TRANSFORMERS_INSTALLED = True
except ImportError:
    IS_TRANSFORMERS_INSTALLED = False

import warnings


class ReplicationPad1d(nn.Module):
    """Pad the last dimension of a 3-D tensor by repeating its edge values.

    Args:
        padding (tuple): `(left, right)` number of positions to add before
            and after the last dimension. A single-element sequence is
            interpreted as right padding only.
    """

    def __init__(self, padding):
        super().__init__()
        self.padding = padding

    def forward(self, input):
        """Return `input` extended along its last dimension.

        Args:
            input (torch.Tensor): tensor indexed as `[:, :, time]`.

        Returns:
            torch.Tensor: `input` with `left` copies of its first value
            prepended and `right` copies of its last value appended on the
            last dimension.
        """
        # Only honour a left width when one was actually supplied, so that a
        # one-element `padding` keeps meaning "right padding only".
        left = self.padding[0] if len(self.padding) > 1 else 0
        right = self.padding[-1]

        pieces = []
        if left > 0:
            pieces.append(input[:, :, 0].unsqueeze(-1).repeat(1, 1, left))
        pieces.append(input)
        pieces.append(input[:, :, -1].unsqueeze(-1).repeat(1, 1, right))
        return torch.cat(pieces, dim=-1)


class TokenEmbedding(nn.Module):
    """Embed each token with a bias-free, circularly padded 1-D convolution.

    Args:
        c_in (int): size of the last dimension of the input (conv channels).
        d_model (int): number of output channels of the convolution.
    """

    def __init__(self, c_in, d_model):
        super().__init__()
        # The padding width depends on the installed torch version.
        if torch.__version__ >= "1.5.0":
            conv_padding = 1
        else:
            conv_padding = 2

        self.tokenConv = nn.Conv1d(
            c_in,
            d_model,
            kernel_size=3,
            padding=conv_padding,
            padding_mode="circular",
            bias=False,
        )
        # The convolution is the only Conv1d in this module; re-initialise it.
        nn.init.kaiming_normal_(
            self.tokenConv.weight, mode="fan_in", nonlinearity="leaky_relu"
        )

    def forward(self, x):
        """Apply the convolution over dimension 1 of `x`.

        The last two dimensions are swapped before the convolution and
        swapped back afterwards.
        """
        channels_first = x.permute(0, 2, 1)
        embedded = self.tokenConv(channels_first)
        return embedded.transpose(1, 2)


class PatchEmbedding(nn.Module):
    """Cut each series into overlapping patches and embed every patch.

    Args:
        d_model (int): embedding size produced for each patch.
        patch_len (int): number of time steps per patch.
        stride (int): step between consecutive patches; also the amount of
            right padding added before patching.
        dropout (float): dropout probability applied to the embeddings.
    """

    def __init__(self, d_model, patch_len, stride, dropout):
        super().__init__()
        self.patch_len = patch_len
        self.stride = stride

        # Extend the series on the right by `stride` replicated steps.
        self.padding_patch_layer = ReplicationPad1d((0, stride))
        # Project every patch (length `patch_len`) onto `d_model` features.
        self.value_embedding = TokenEmbedding(patch_len, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Embed the patches of `x`.

        Args:
            x (torch.Tensor): tensor whose dimension 1 indexes variables and
                whose last dimension indexes time.

        Returns:
            tuple: the embedded patches, with the first two dimensions of
            the patched tensor merged into one, and the number of variables
            (`x.shape[1]`).
        """
        n_vars = x.shape[1]

        padded = self.padding_patch_layer(x)
        patches = padded.unfold(dimension=-1, size=self.patch_len, step=self.stride)

        # Fold the variable dimension into the batch dimension.
        d0, d1, d2, d3 = patches.shape[:4]
        patches = torch.reshape(patches, (d0 * d1, d2, d3))

        embedded = self.value_embedding(patches)
        return self.dropout(embedded), n_vars


class FlattenHead(nn.Module):
    """Flatten the last two dimensions and project them linearly.

    Args:
        n_vars (int): number of variables; stored on the module only.
        nf (int): size of the flattened last two dimensions.
        target_window (int): size of the projected output.
        head_dropout (float): dropout probability applied after projecting.
    """

    def __init__(self, n_vars, nf, target_window, head_dropout=0):
        super().__init__()
        self.n_vars = n_vars
        self.flatten = nn.Flatten(start_dim=-2)
        self.linear = nn.Linear(nf, target_window)
        self.dropout = nn.Dropout(head_dropout)

    def forward(self, x):
        """Return `dropout(linear(flatten(x)))`."""
        flat = self.flatten(x)
        projected = self.linear(flat)
        return self.dropout(projected)


class ReprogrammingLayer(nn.Module):
    """Multi-head cross-attention from target embeddings to source embeddings.

    Queries come from the target embeddings, keys and values from the source
    and value embeddings; the attended result is projected to `d_llm`.

    Args:
        d_model (int): feature size of the target embeddings.
        n_heads (int): number of attention heads.
        d_keys (int): per-head feature size; falls back to
            `d_model // n_heads` when not given.
        d_llm (int): feature size of the source/value embeddings and of the
            layer output.
        attention_dropout (float): dropout applied to the attention weights.
    """

    def __init__(
        self, d_model, n_heads, d_keys=None, d_llm=None, attention_dropout=0.1
    ):
        super().__init__()

        if not d_keys:
            d_keys = d_model // n_heads
        inner_dim = d_keys * n_heads

        self.query_projection = nn.Linear(d_model, inner_dim)
        self.key_projection = nn.Linear(d_llm, inner_dim)
        self.value_projection = nn.Linear(d_llm, inner_dim)
        self.out_projection = nn.Linear(inner_dim, d_llm)
        self.n_heads = n_heads
        self.dropout = nn.Dropout(attention_dropout)

    def forward(self, target_embedding, source_embedding, value_embedding):
        """Attend from `target_embedding` (3-D) to `source_embedding` (2-D).

        Returns a tensor with the first two dimensions of `target_embedding`
        and the output size of `out_projection` as last dimension.
        """
        batch, length, _ = target_embedding.shape
        n_source, _ = source_embedding.shape
        heads = self.n_heads

        # Split the projected features into heads.
        queries = self.query_projection(target_embedding).view(
            batch, length, heads, -1
        )
        keys = self.key_projection(source_embedding).view(n_source, heads, -1)
        values = self.value_projection(value_embedding).view(n_source, heads, -1)

        attended = self.reprogramming(queries, keys, values)

        # Merge the heads back before the output projection.
        return self.out_projection(attended.reshape(batch, length, -1))

    def reprogramming(self, target_embedding, source_embedding, value_embedding):
        """Scaled dot-product attention over already head-split tensors."""
        _, _, _, head_dim = target_embedding.shape
        scale = 1.0 / math.sqrt(head_dim)

        scores = torch.einsum("blhe,she->bhls", target_embedding, source_embedding)
        weights = torch.softmax(scale * scores, dim=-1)
        weights = self.dropout(weights)

        return torch.einsum("bhls,she->blhe", weights, value_embedding)


class TimeLLM(BaseModel):
    """TimeLLM

    Time-LLM is a reprogramming framework to repurpose an off-the-shelf LLM for time series forecasting.

    It trains a reprogramming layer that translates the observed series into a language task. This is fed to the LLM and an output
    projection layer translates the output back to numerical predictions.

    Args:
        h (int): Forecast horizon.
        input_size (int): autoregressive inputs size, y=[1,2,3,4] input_size=2 -> y_[t-2:t]=[1,2].
        patch_len (int): length of patch. Default: 16
        stride (int): stride of patch. Default: 8
        d_ff (int): dimension of fcn. Default: 128
        top_k (int): top tokens to consider. Default: 5
        d_llm (int): hidden dimension of LLM. Default: 768 # LLama7b:4096; GPT2-small:768; BERT-base:768
        d_model (int): dimension of model. Default: 32
        n_heads (int): number of heads in attention layer. Default: 8
        enc_in (int): encoder input size. Default: 7
        dec_in (int): decoder input size. Default: 7
        llm (str): Path to pretrained LLM model to use. If not specified, it will use GPT-2 from https://huggingface.co/openai-community/gpt2
        llm_config (dict): Deprecated and ignored; the configuration is loaded from `llm`.
        llm_tokenizer (str): Deprecated and ignored; the tokenizer is loaded from `llm`.
        llm_num_hidden_layers (int): hidden layers in LLM. Default: 32
        llm_output_attention (bool): whether to output attention in encoder. Default: True
        llm_output_hidden_states (bool): whether to output hidden states. Default: True
        prompt_prefix (str): prompt to inform the LLM about the dataset. Default: None (no prefix)
        dropout (float): dropout rate. Default: 0.1
        stat_exog_list (list): static exogenous columns.
        hist_exog_list (list): historic exogenous columns.
        futr_exog_list (list): future exogenous columns.
        loss (PyTorch module): instantiated train loss class from [losses collection](./losses.pytorch).
        valid_loss (PyTorch module): instantiated valid loss class from [losses collection](./losses.pytorch).
        learning_rate (float): Learning rate between (0, 1). Default: 1e-4
        max_steps (int): maximum number of training steps. Default: 5
        val_check_steps (int): Number of training steps between every validation loss check. Default: 100
        batch_size (int): number of different series in each batch. Default: 32
        valid_batch_size (int): number of different series in each validation and test batch, if None uses batch_size. Default: None
        windows_batch_size (int): number of windows to sample in each training batch, default uses all. Default: 1024
        inference_windows_batch_size (int): number of windows to sample in each inference batch. Default: 1024
        start_padding_enabled (bool): if True, the model will pad the time series with zeros at the beginning, by input size. Default: False
        training_data_availability_threshold (Union[float, List[float]]): minimum fraction of valid data points required for training windows. Single float applies to both insample and outsample; list of two floats specifies [insample_fraction, outsample_fraction]. Default 0.0 allows windows with only 1 valid data point (current behavior).
        step_size (int): step size between each window of temporal data. Default: 1
        num_lr_decays (int): Number of learning rate decays, evenly distributed across max_steps. Default: 0
        early_stop_patience_steps (int): Number of validation iterations before early stopping. Default: -1
        scaler_type (str): type of scaler for temporal inputs normalization see [temporal scalers](https://github.com/Nixtla/neuralforecast/blob/main/neuralforecast/common/_scalers.py). Default: 'identity'
        random_seed (int): random_seed for pytorch initializer and numpy generators. Default: 1
        drop_last_loader (bool): if True `TimeSeriesDataLoader` drops last non-full batch. Default: False
        alias (str): optional,  Custom name of the model.
        optimizer (Subclass of 'torch.optim.Optimizer'): optional, user specified optimizer instead of the default choice (Adam).
        optimizer_kwargs (dict): optional, list of parameters used by the user specified `optimizer`.
        lr_scheduler (Subclass of 'torch.optim.lr_scheduler.LRScheduler'): optional, user specified lr_scheduler instead of the default choice (StepLR).
        lr_scheduler_kwargs (dict): optional, list of parameters used by the user specified `lr_scheduler`.
        dataloader_kwargs (dict): optional, list of parameters passed into the PyTorch Lightning dataloader by the `TimeSeriesDataLoader`.
        **trainer_kwargs (int):  keyword trainer arguments inherited from [PyTorch Lightning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).

    Raises:
        Exception: if `loss` has an output size multiplier above 1 or `valid_loss` is not a point loss.
        ImportError: if the `transformers` package is not installed.

    References:
        - [Ming Jin, Shiyu Wang, Lintao Ma, Zhixuan Chu, James Y. Zhang, Xiaoming Shi, Pin-Yu Chen, Yuxuan Liang, Yuan-Fang Li, Shirui Pan, Qingsong Wen. "Time-LLM: Time Series Forecasting by Reprogramming Large Language Models"](https://arxiv.org/abs/2310.01728)

    """

    EXOGENOUS_FUTR = False
    EXOGENOUS_HIST = False
    EXOGENOUS_STAT = False
    MULTIVARIATE = False  # If the model produces multivariate forecasts (True) or univariate (False)
    RECURRENT = (
        False  # If the model produces forecasts recursively (True) or direct (False)
    )

    def __init__(
        self,
        h,
        input_size,
        patch_len: int = 16,
        stride: int = 8,
        d_ff: int = 128,
        top_k: int = 5,
        d_llm: int = 768,
        d_model: int = 32,
        n_heads: int = 8,
        enc_in: int = 7,
        dec_in: int = 7,
        llm=None,
        llm_config=None,
        llm_tokenizer=None,
        llm_num_hidden_layers=32,
        llm_output_attention: bool = True,
        llm_output_hidden_states: bool = True,
        prompt_prefix: Optional[str] = None,
        dropout: float = 0.1,
        stat_exog_list=None,
        hist_exog_list=None,
        futr_exog_list=None,
        loss=MAE(),
        valid_loss=None,
        learning_rate: float = 1e-4,
        max_steps: int = 5,
        val_check_steps: int = 100,
        batch_size: int = 32,
        valid_batch_size: Optional[int] = None,
        windows_batch_size: int = 1024,
        inference_windows_batch_size: int = 1024,
        start_padding_enabled: bool = False,
        training_data_availability_threshold=0.0,
        step_size: int = 1,
        num_lr_decays: int = 0,
        early_stop_patience_steps: int = -1,
        scaler_type: str = "identity",
        random_seed: int = 1,
        drop_last_loader: bool = False,
        alias: Optional[str] = None,
        optimizer=None,
        optimizer_kwargs=None,
        lr_scheduler=None,
        lr_scheduler_kwargs=None,
        dataloader_kwargs=None,
        **trainer_kwargs,
    ):
        super(TimeLLM, self).__init__(
            h=h,
            input_size=input_size,
            hist_exog_list=hist_exog_list,
            stat_exog_list=stat_exog_list,
            futr_exog_list=futr_exog_list,
            loss=loss,
            valid_loss=valid_loss,
            max_steps=max_steps,
            learning_rate=learning_rate,
            num_lr_decays=num_lr_decays,
            early_stop_patience_steps=early_stop_patience_steps,
            val_check_steps=val_check_steps,
            batch_size=batch_size,
            valid_batch_size=valid_batch_size,
            windows_batch_size=windows_batch_size,
            inference_windows_batch_size=inference_windows_batch_size,
            start_padding_enabled=start_padding_enabled,
            training_data_availability_threshold=training_data_availability_threshold,
            step_size=step_size,
            scaler_type=scaler_type,
            drop_last_loader=drop_last_loader,
            alias=alias,
            random_seed=random_seed,
            optimizer=optimizer,
            optimizer_kwargs=optimizer_kwargs,
            lr_scheduler=lr_scheduler,
            lr_scheduler_kwargs=lr_scheduler_kwargs,
            dataloader_kwargs=dataloader_kwargs,
            **trainer_kwargs,
        )
        if loss.outputsize_multiplier > 1:
            raise Exception(
                "TimeLLM only supports point loss functions (MAE, MSE, etc) as loss function."
            )

        if valid_loss is not None and not isinstance(valid_loss, losses.BasePointLoss):
            raise Exception(
                "TimeLLM only supports point loss functions (MAE, MSE, etc) as valid loss function."
            )

        # Architecture
        self.patch_len = patch_len
        self.stride = stride
        self.d_ff = d_ff
        self.top_k = top_k
        self.d_llm = d_llm
        self.d_model = d_model
        self.dropout = dropout
        self.n_heads = n_heads
        self.enc_in = enc_in
        self.dec_in = dec_in

        DEFAULT_MODEL = "openai-community/gpt2"

        # Every code path below loads the LLM through `transformers`, so the
        # dependency is required whether or not a custom `llm` is given.
        if not IS_TRANSFORMERS_INSTALLED:
            raise ImportError("Please install `transformers` to use TimeLLM.")

        if llm is None:
            print(f"Using {DEFAULT_MODEL} as default.")
            model_name = DEFAULT_MODEL
        else:
            model_name = llm

        if llm_config is not None or llm_tokenizer is not None:
            warnings.warn(
                "'llm_config' and 'llm_tokenizer' parameters are deprecated and will be ignored. "
                "The config and tokenizer will be automatically loaded from the specified model.",
                DeprecationWarning,
                stacklevel=2,
            )

        try:
            self.llm_config = AutoConfig.from_pretrained(model_name)
            self.llm = AutoModel.from_pretrained(model_name, config=self.llm_config)
            self.llm_tokenizer = AutoTokenizer.from_pretrained(model_name)
            print(f"Successfully loaded model: {model_name}")
        except EnvironmentError:
            # Best effort: fall back to the default model when the requested
            # one cannot be loaded.
            print(
                f"Failed to load {model_name}. Loading the default model ({DEFAULT_MODEL})..."
            )
            self.llm_config = AutoConfig.from_pretrained(DEFAULT_MODEL)
            self.llm = AutoModel.from_pretrained(DEFAULT_MODEL, config=self.llm_config)
            self.llm_tokenizer = AutoTokenizer.from_pretrained(DEFAULT_MODEL)

        self.llm_num_hidden_layers = llm_num_hidden_layers
        self.llm_output_attention = llm_output_attention
        self.llm_output_hidden_states = llm_output_hidden_states
        self.prompt_prefix = prompt_prefix

        # Prompts are padded to a common length, so a pad token is required.
        if self.llm_tokenizer.eos_token:
            self.llm_tokenizer.pad_token = self.llm_tokenizer.eos_token
        else:
            # NOTE(review): the LLM embedding matrix is not resized here; if
            # "[PAD]" is genuinely new to the vocabulary its id may fall
            # outside the embedding table. Confirm for such tokenizers.
            pad_token = "[PAD]"
            self.llm_tokenizer.add_special_tokens({"pad_token": pad_token})
            self.llm_tokenizer.pad_token = pad_token

        # The LLM backbone stays frozen; only the layers below are trained.
        for param in self.llm.parameters():
            param.requires_grad = False

        self.patch_embedding = PatchEmbedding(
            self.d_model, self.patch_len, self.stride, self.dropout
        )

        self.word_embeddings = self.llm.get_input_embeddings().weight
        self.vocab_size = self.word_embeddings.shape[0]
        self.num_tokens = 1024
        # Maps the full vocabulary onto `num_tokens` source embeddings.
        self.mapping_layer = nn.Linear(self.vocab_size, self.num_tokens)

        self.reprogramming_layer = ReprogrammingLayer(
            self.d_model, self.n_heads, self.d_ff, self.d_llm
        )

        # The "+ 2" accounts for the `stride` steps of right padding added by
        # PatchEmbedding before unfolding.
        self.patch_nums = int((input_size - self.patch_len) / self.stride + 2)
        self.head_nf = self.d_ff * self.patch_nums

        self.output_projection = FlattenHead(
            self.enc_in, self.head_nf, self.h, head_dropout=self.dropout
        )

        self.normalize_layers = RevIN(self.enc_in, affine=False)

    def forecast(self, x_enc):
        """Forecast from a batch of input windows.

        Builds one text prompt per series from summary statistics of the
        normalized input, embeds the prompt with the LLM input embeddings,
        reprograms the patch embeddings into the LLM space, runs the frozen
        LLM on the concatenation and projects the last `patch_nums` outputs
        to the horizon.

        Args:
            x_enc (torch.Tensor): input of size (B, T, N).

        Returns:
            torch.Tensor: denormalized forecasts of size (B, h, N).
        """
        x_enc = self.normalize_layers(x_enc, "norm")

        B, T, N = x_enc.size()
        # One row per (batch, variable) pair: (B * N, T, 1).
        x_enc = x_enc.permute(0, 2, 1).contiguous().reshape(B * N, T, 1)

        # Convert each statistic to Python values once rather than once per
        # row inside the prompt loop.
        min_values = torch.min(x_enc, dim=1)[0][:, 0].tolist()
        max_values = torch.max(x_enc, dim=1)[0][:, 0].tolist()
        medians = torch.median(x_enc, dim=1).values[:, 0].tolist()
        lags = self.calcute_lags(x_enc).tolist()
        trends = x_enc.diff(dim=1).sum(dim=1)[:, 0].tolist()

        # A missing prefix must not be rendered as the literal text "None".
        prefix = "" if self.prompt_prefix is None else self.prompt_prefix

        prompts = [
            (
                f"<|start_prompt|>{prefix}"
                f"Task description: forecast the next {str(self.h)} steps given the previous {str(self.input_size)} steps information; "
                "Input statistics: "
                f"min value {str(min_value)}, "
                f"max value {str(max_value)}, "
                f"median value {str(median)}, "
                f"the trend of input is {'upward' if trend > 0 else 'downward'}, "
                f"top {self.top_k} lags are : {str(series_lags)}<|<end_prompt>|>"
            )
            for min_value, max_value, median, trend, series_lags in zip(
                min_values, max_values, medians, trends, lags
            )
        ]

        # Back to (B, N, T), the layout PatchEmbedding consumes.
        x_enc = x_enc.reshape(B, N, T)

        prompt_ids = self.llm_tokenizer(
            prompts, return_tensors="pt", padding=True, truncation=True, max_length=2048
        ).input_ids
        prompt_embeddings = self.llm.get_input_embeddings()(
            prompt_ids.to(x_enc.device)
        )  # (batch, prompt_token, dim)

        # Compress the vocabulary embeddings into `num_tokens` prototypes.
        source_embeddings = self.mapping_layer(
            self.word_embeddings.permute(1, 0)
        ).permute(1, 0)

        enc_out, n_vars = self.patch_embedding(x_enc.to(torch.float32))
        enc_out = self.reprogramming_layer(
            enc_out, source_embeddings, source_embeddings
        )
        llm_enc_out = torch.cat([prompt_embeddings, enc_out], dim=1)
        dec_out = self.llm(inputs_embeds=llm_enc_out).last_hidden_state
        # Keep only the first `d_ff` hidden features.
        dec_out = dec_out[:, :, : self.d_ff]

        dec_out = torch.reshape(
            dec_out, (-1, n_vars, dec_out.shape[-2], dec_out.shape[-1])
        )
        dec_out = dec_out.permute(0, 1, 3, 2).contiguous()

        # Only the trailing `patch_nums` positions (the patches, not the
        # prompt tokens) feed the output head.
        dec_out = self.output_projection(dec_out[:, :, :, -self.patch_nums :])
        dec_out = dec_out.permute(0, 2, 1).contiguous()

        dec_out = self.normalize_layers(dec_out, "denorm")

        return dec_out

    def calcute_lags(self, x_enc):
        """Return the `top_k` lags with the largest autocorrelation.

        The autocorrelation is obtained through the FFT as
        `irfft(rfft(x) * conj(rfft(x)))` along the time dimension.

        Args:
            x_enc (torch.Tensor): tensor whose dimension 1 indexes time.

        Returns:
            torch.Tensor: indices of the `top_k` largest values, one row per
            entry of the first dimension of `x_enc`.
        """
        # Query and key are the same series, so one transform is enough.
        spectrum = torch.fft.rfft(x_enc.permute(0, 2, 1).contiguous(), dim=-1)
        res = spectrum * torch.conj(spectrum)
        # NOTE: `irfft` is called without `n`, so its output length is
        # derived from the spectrum size rather than the input length.
        corr = torch.fft.irfft(res, dim=-1)
        mean_value = torch.mean(corr, dim=1)
        _, lags = torch.topk(mean_value, self.top_k, dim=-1)
        return lags

    def forward(self, windows_batch):
        """Forecast the `insample_y` entry of `windows_batch`.

        Args:
            windows_batch (dict): batch holding the input windows under the
                key "insample_y".

        Returns:
            torch.Tensor: the last `h` steps produced by `forecast`.
        """
        x = windows_batch["insample_y"]

        y_pred = self.forecast(x)
        y_pred = y_pred[:, -self.h :, :]

        return y_pred
